/*
 * Copyright (c) 2005 Chris Richardson
 * 
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * 
 *      http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.chrisrichardson.ormunit.hibernate;

import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import javax.sql.DataSource;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.orm.hibernate3.LocalSessionFactoryBean;

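/**
 * Resets the database by dropping and re-creating the schema via
 * LocalSessionFactoryBean's dropDatabaseSchema() and createDatabaseSchema().
 * The rows of any tables listed in tablesToPreserve are copied before the
 * schema is dropped and re-inserted after it has been re-created.
 */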

public class ResetDatabaseByRecreatingSchemaStrategy implements
		DatabaseResetStrategy, ApplicationContextAware {

	private final Log logger = LogFactory.getLog(getClass());

	private ApplicationContext applicationContext;

	private JdbcTemplate jdbcTemplate;

	/**
	 * Names of tables whose rows are copied before the schema is dropped and
	 * re-inserted after it is re-created. Never null.
	 */
	private List<String> tablesToPreserve = Collections.<String> emptyList();

	/**
	 * Stores the context used to look up the LocalSessionFactoryBean.
	 */
	public void setApplicationContext(ApplicationContext applicationContext)
			throws BeansException {
		this.applicationContext = applicationContext;
	}

	/**
	 * Sets the tables whose contents survive a reset.
	 * 
	 * @param tablesToPreserve
	 *            table names; rows are restored in this list's order. A null
	 *            value is treated as an empty list.
	 */
	public void setTablesToPreserve(List<String> tablesToPreserve) {
		this.tablesToPreserve = tablesToPreserve == null ? Collections
				.<String> emptyList() : tablesToPreserve;
	}

	/**
	 * Sets the data source used to read and restore preserved tables.
	 */
	public void setDataSource(DataSource dataSource) {
		this.jdbcTemplate = new JdbcTemplate(dataSource);
	}

	/**
	 * Snapshots every table in tablesToPreserve, drops and re-creates the
	 * schema through the LocalSessionFactoryBean found in the application
	 * context, then writes the snapshots back.
	 */
	public void resetDatabase() {
		// TODO consider using DBUnit for this
		LocalSessionFactoryBean localSessionFactoryBean = LocalSessionFactoryBeanUtil
				.getLocalSessionFactoryBean(applicationContext);

		// LinkedHashMap so tables are restored in the configured order.
		Map<String, List<Object[]>> preservedData = new LinkedHashMap<String, List<Object[]>>();
		for (String tableToPreserve : tablesToPreserve) {
			List<Object[]> rows = getDataForTable(tableToPreserve);
			if (logger.isDebugEnabled()) {
				logger.debug("Preserving " + (rows == null ? 0 : rows.size())
						+ " row(s) of table " + tableToPreserve);
			}
			preservedData.put(tableToPreserve, rows);
		}

		localSessionFactoryBean.dropDatabaseSchema();
		localSessionFactoryBean.createDatabaseSchema();

		for (Map.Entry<String, List<Object[]>> entry : preservedData.entrySet()) {
			restoreTable(entry.getKey(), entry.getValue());
		}
	}

	/**
	 * Replaces the contents of a table with previously captured rows.
	 * 
	 * @param tableName
	 *            table to restore; concatenated into SQL (identifiers cannot
	 *            be bound as parameters), so it must come from trusted
	 *            configuration
	 * @param rows
	 *            rows captured by getDataForTable; null means nothing to do
	 */
	private void restoreTable(String tableName, List<Object[]> rows) {
		if (rows == null) {
			return;
		}
		// Remove any rows already present so the table ends up holding
		// exactly the preserved rows.
		jdbcTemplate.update("DELETE FROM " + tableName);
		if (rows.isEmpty()) {
			// The table was empty when captured: nothing to insert, and no
			// first row to take the column count from.
			return;
		}
		String insertSql = buildInsertSql(tableName, rows.get(0).length);
		for (Object[] row : rows) {
			jdbcTemplate.update(insertSql, row);
		}
	}

	/**
	 * Builds "INSERT INTO table values(?,?,...)" with one placeholder per
	 * column.
	 */
	private static String buildInsertSql(String tableName, int columnCount) {
		StringBuilder sql = new StringBuilder("INSERT INTO ").append(tableName)
				.append(" values(");
		for (int i = 0; i < columnCount; i++) {
			if (i > 0) {
				sql.append(",");
			}
			sql.append("?");
		}
		return sql.append(")").toString();
	}

	/**
	 * Reads every row of a table as an array of column values in column
	 * order.
	 * 
	 * @param tableToPreserve
	 *            table name; concatenated into SQL, so it must come from
	 *            trusted configuration
	 * @return one Object[] per row (empty list for an empty table)
	 */
	@SuppressWarnings("unchecked")
	private List<Object[]> getDataForTable(String tableToPreserve) {
		return jdbcTemplate.query("select * from " + tableToPreserve,
				new RowMapper() {

					public Object mapRow(ResultSet rs, int rowNum)
							throws SQLException {
						int n = rs.getMetaData().getColumnCount();
						Object[] result = new Object[n];
						for (int i = 1; i <= n; i++) {
							result[i - 1] = rs.getObject(i);
						}
						return result;
					}
				});
	}

}
